{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def padding_sequence(seq, max_len = 501, repkey = 'N'):\n",
    "    seq_len = len(seq)\n",
    "    if seq_len < max_len:\n",
    "        gap_len = max_len -seq_len\n",
    "        new_seq = seq + repkey * gap_len\n",
    "    else:\n",
    "        new_seq = seq[:max_len]\n",
    "    return new_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_RNA_seq_concolutional_array(seq, motif_len = 4):\n",
    "    seq = seq.replace('U', 'T')\n",
    "    alpha = 'ACGT'\n",
    "    #for seq in seqs:\n",
    "    #for key, seq in seqs.iteritems():\n",
    "    row = (len(seq) + 2*motif_len - 2)\n",
    "    new_array = np.zeros((row, 4))\n",
    "    for i in range(motif_len-1):\n",
    "        new_array[i] = np.array([0.25]*4)\n",
    "    \n",
    "    for i in range(row-3, row):\n",
    "        new_array[i] = np.array([0.25]*4)\n",
    "        \n",
    "    #pdb.set_trace()\n",
    "    for i, val in enumerate(seq):\n",
    "        i = i + motif_len-1\n",
    "        if val not in 'ACGT':\n",
    "            new_array[i] = np.array([0.25]*4)\n",
    "            continue\n",
    "        #if val == 'N' or i < motif_len or i > len(seq) - motif_len:\n",
    "        #    new_array[i] = np.array([0.25]*4)\n",
    "        #else:\n",
    "        try:\n",
    "            index = alpha.index(val)\n",
    "            new_array[i][index] = 1\n",
    "        except:\n",
    "            pdb.set_trace()\n",
    "        #data[key] = new_array\n",
    "    return new_array\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bag_data_1_channel(data, max_len = 501):\n",
    "    bags = []\n",
    "    seqs = data[\"seq\"]\n",
    "    labels = data[\"Y\"]\n",
    "    for seq in seqs:\n",
    "        #pdb.set_trace()\n",
    "        #bag_seqs = split_overlap_seq(seq)\n",
    "        bag_seq = padding_sequence(seq, max_len = max_len)\n",
    "        #flat_array = []\n",
    "        bag_subt = []\n",
    "        #for bag_seq in bag_seqs:\n",
    "        tri_fea = get_RNA_seq_concolutional_array(bag_seq)\n",
    "        bag_subt.append(tri_fea.T)\n",
    "\n",
    "        \n",
    "        bags.append(np.array(bag_subt))\n",
    "    \n",
    "        \n",
    "    return bags, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_seq_graphprot(seq_file, label = 1):\n",
    "    seq_list = []\n",
    "    labels = []\n",
    "    seq = ''\n",
    "    with open(seq_file, 'r') as fp:\n",
    "        for line in fp:\n",
    "            if line[0] == '>':\n",
    "                name = line[1:-1]\n",
    "            else:\n",
    "                seq = line[:-1].upper()\n",
    "                seq = seq.replace('T', 'U')\n",
    "                seq_list.append(seq)\n",
    "                labels.append(label)\n",
    "    \n",
    "    return seq_list, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data_file(posifile, negafile = None, train = True):\n",
    "    data = dict()\n",
    "    seqs, labels = read_seq_graphprot(posifile, label = 1)\n",
    "    if negafile:\n",
    "        seqs2, labels2 = read_seq_graphprot(negafile, label = 0)\n",
    "        seqs = seqs + seqs2\n",
    "        labels = labels + labels2\n",
    "        \n",
    "    data[\"seq\"] = seqs\n",
    "    data[\"Y\"] = np.array(labels)\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(posi, nega = None, channel = 7,  window_size = 101, train = True):\n",
    "    data = read_data_file(posi, nega, train = train)\n",
    "    if channel == 1:\n",
    "        train_bags, label = get_bag_data_1_channel(data, max_len = window_size)\n",
    "\n",
    "    else:\n",
    "        train_bags, label = get_bag_data(data, channel = channel, window_size = window_size)\n",
    "    \n",
    "    return train_bags, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, l = get_data('./CLIPSEQ_AGO2.train.positives.fa', './CLIPSEQ_AGO2.train.negatives.fa', channel=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "f_test, l_test = get_data('./CLIPSEQ_AGO2.ls.positives.fa','CLIPSEQ_AGO2.ls.negatives.fa', channel=1, train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = np.array(f)\n",
    "f = np.swapaxes(f, -1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "f_test = np.array(f_test)\n",
    "f_test = np.swapaxes(f_test, -1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "from keras import backend as K\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_1 (Conv2D)            (None, 107, 4, 128)       3968      \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 107, 4, 128)       0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 35, 4, 128)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 35, 4, 128)        163968    \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 35, 4, 128)        0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 11, 4, 128)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 11, 4, 256)        164096    \n",
      "_________________________________________________________________\n",
      "dropout_3 (Dropout)          (None, 11, 4, 256)        0         \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d_1 ( (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dropout_4 (Dropout)          (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 128)               32896     \n",
      "_________________________________________________________________\n",
      "dropout_5 (Dropout)          (None, 128)               0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 129       \n",
      "=================================================================\n",
      "Total params: 365,057\n",
      "Trainable params: 365,057\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Conv2D(128, kernel_size=(10, 3),\n",
    "                 activation='relu',\n",
    "                 input_shape=f.shape[1:], padding='same'))\n",
    "model.add(Dropout(0.25))\n",
    "model.add(MaxPooling2D(pool_size=(3, 1)))\n",
    "model.add(Conv2D(128, (10, 1), activation='relu', padding='same'))\n",
    "model.add(Dropout(0.25))\n",
    "model.add(MaxPooling2D(pool_size=(3, 1)))\n",
    "model.add(Conv2D(256, (5, 1), activation='relu', padding='same'))\n",
    "model.add(Dropout(0.25))\n",
    "model.add(keras.layers.GlobalAveragePooling2D())\n",
    "model.add(Dropout(0.25))\n",
    "# model.add(Flatten())\n",
    "model.add(Dense(128, activation='relu'))\n",
    "model.add(Dropout(0.5))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=keras.losses.binary_crossentropy,\n",
    "              optimizer=keras.optimizers.Adadelta(),\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 92346 samples, validate on 1000 samples\n",
      "Epoch 1/50\n",
      "92346/92346 [==============================] - 10s 112us/step - loss: 0.6875 - acc: 0.5432 - val_loss: 0.6794 - val_acc: 0.5720\n",
      "Epoch 2/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6789 - acc: 0.5752 - val_loss: 0.6701 - val_acc: 0.6140\n",
      "Epoch 3/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6727 - acc: 0.5904 - val_loss: 0.6615 - val_acc: 0.6340\n",
      "Epoch 4/50\n",
      "92346/92346 [==============================] - 8s 85us/step - loss: 0.6663 - acc: 0.5998 - val_loss: 0.6530 - val_acc: 0.6180\n",
      "Epoch 5/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6591 - acc: 0.6053 - val_loss: 0.6554 - val_acc: 0.6380\n",
      "Epoch 6/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6520 - acc: 0.6112 - val_loss: 0.6316 - val_acc: 0.6220\n",
      "Epoch 7/50\n",
      "92346/92346 [==============================] - 8s 85us/step - loss: 0.6472 - acc: 0.6131 - val_loss: 0.6295 - val_acc: 0.6190\n",
      "Epoch 8/50\n",
      "92346/92346 [==============================] - 8s 84us/step - loss: 0.6425 - acc: 0.6179 - val_loss: 0.6274 - val_acc: 0.6290\n",
      "Epoch 9/50\n",
      "92346/92346 [==============================] - 8s 85us/step - loss: 0.6399 - acc: 0.6194 - val_loss: 0.6299 - val_acc: 0.6250\n",
      "Epoch 10/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6368 - acc: 0.6228 - val_loss: 0.6298 - val_acc: 0.6320\n",
      "Epoch 11/50\n",
      "92346/92346 [==============================] - 8s 87us/step - loss: 0.6354 - acc: 0.6230 - val_loss: 0.6280 - val_acc: 0.6230\n",
      "Epoch 12/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6325 - acc: 0.6273 - val_loss: 0.6207 - val_acc: 0.6340\n",
      "Epoch 13/50\n",
      "92346/92346 [==============================] - 8s 86us/step - loss: 0.6309 - acc: 0.6275 - val_loss: 0.6222 - val_acc: 0.6350\n",
      "Epoch 14/50\n",
      "92346/92346 [==============================] - 8s 87us/step - loss: 0.6293 - acc: 0.6307 - val_loss: 0.6297 - val_acc: 0.6200\n",
      "Epoch 15/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6279 - acc: 0.6324 - val_loss: 0.6199 - val_acc: 0.6280\n",
      "Epoch 16/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6256 - acc: 0.6357 - val_loss: 0.6209 - val_acc: 0.6400\n",
      "Epoch 17/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6233 - acc: 0.6371 - val_loss: 0.6239 - val_acc: 0.6230\n",
      "Epoch 18/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6224 - acc: 0.6387 - val_loss: 0.6194 - val_acc: 0.6370\n",
      "Epoch 19/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.6192 - acc: 0.6413 - val_loss: 0.6216 - val_acc: 0.6510\n",
      "Epoch 20/50\n",
      "92346/92346 [==============================] - 8s 89us/step - loss: 0.6190 - acc: 0.6431 - val_loss: 0.6148 - val_acc: 0.6490\n",
      "Epoch 21/50\n",
      "92346/92346 [==============================] - 8s 89us/step - loss: 0.6163 - acc: 0.6462 - val_loss: 0.6131 - val_acc: 0.6430\n",
      "Epoch 22/50\n",
      "92346/92346 [==============================] - 8s 89us/step - loss: 0.6133 - acc: 0.6507 - val_loss: 0.6216 - val_acc: 0.6350\n",
      "Epoch 23/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.6117 - acc: 0.6518 - val_loss: 0.6102 - val_acc: 0.6420\n",
      "Epoch 24/50\n",
      "92346/92346 [==============================] - 8s 92us/step - loss: 0.6111 - acc: 0.6540 - val_loss: 0.6101 - val_acc: 0.6480\n",
      "Epoch 25/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.6081 - acc: 0.6577 - val_loss: 0.6231 - val_acc: 0.6220\n",
      "Epoch 26/50\n",
      "92346/92346 [==============================] - 9s 92us/step - loss: 0.6066 - acc: 0.6576 - val_loss: 0.6128 - val_acc: 0.6380\n",
      "Epoch 27/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.6023 - acc: 0.6638 - val_loss: 0.6183 - val_acc: 0.6360\n",
      "Epoch 28/50\n",
      "92346/92346 [==============================] - 8s 89us/step - loss: 0.6015 - acc: 0.6633 - val_loss: 0.6122 - val_acc: 0.6470\n",
      "Epoch 29/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5995 - acc: 0.6672 - val_loss: 0.6291 - val_acc: 0.6310\n",
      "Epoch 30/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5958 - acc: 0.6723 - val_loss: 0.6273 - val_acc: 0.6400\n",
      "Epoch 31/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5930 - acc: 0.6762 - val_loss: 0.6154 - val_acc: 0.6490\n",
      "Epoch 32/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5915 - acc: 0.6758 - val_loss: 0.6152 - val_acc: 0.6440\n",
      "Epoch 33/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5879 - acc: 0.6787 - val_loss: 0.6133 - val_acc: 0.6450\n",
      "Epoch 34/50\n",
      "92346/92346 [==============================] - 8s 92us/step - loss: 0.5860 - acc: 0.6817 - val_loss: 0.6275 - val_acc: 0.6370\n",
      "Epoch 35/50\n",
      "92346/92346 [==============================] - 8s 92us/step - loss: 0.5831 - acc: 0.6843 - val_loss: 0.6135 - val_acc: 0.6530\n",
      "Epoch 36/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5790 - acc: 0.6871 - val_loss: 0.6337 - val_acc: 0.6320\n",
      "Epoch 37/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5767 - acc: 0.6916 - val_loss: 0.6341 - val_acc: 0.6440\n",
      "Epoch 38/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5720 - acc: 0.6960 - val_loss: 0.6247 - val_acc: 0.6550\n",
      "Epoch 39/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5699 - acc: 0.6963 - val_loss: 0.6244 - val_acc: 0.6300\n",
      "Epoch 40/50\n",
      "92346/92346 [==============================] - 8s 89us/step - loss: 0.5671 - acc: 0.7006 - val_loss: 0.6324 - val_acc: 0.6330\n",
      "Epoch 41/50\n",
      "92346/92346 [==============================] - 9s 92us/step - loss: 0.5631 - acc: 0.7017 - val_loss: 0.6827 - val_acc: 0.6060\n",
      "Epoch 42/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5597 - acc: 0.7061 - val_loss: 0.6394 - val_acc: 0.6230\n",
      "Epoch 43/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5567 - acc: 0.7099 - val_loss: 0.6879 - val_acc: 0.6040\n",
      "Epoch 44/50\n",
      "92346/92346 [==============================] - 8s 92us/step - loss: 0.5544 - acc: 0.7093 - val_loss: 0.6370 - val_acc: 0.6210\n",
      "Epoch 45/50\n",
      "92346/92346 [==============================] - 8s 91us/step - loss: 0.5483 - acc: 0.7170 - val_loss: 0.6549 - val_acc: 0.6130\n",
      "Epoch 46/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5454 - acc: 0.7185 - val_loss: 0.6544 - val_acc: 0.6260\n",
      "Epoch 47/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5443 - acc: 0.7187 - val_loss: 0.6463 - val_acc: 0.6290\n",
      "Epoch 48/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5391 - acc: 0.7230 - val_loss: 0.6684 - val_acc: 0.6130\n",
      "Epoch 49/50\n",
      "92346/92346 [==============================] - 8s 90us/step - loss: 0.5348 - acc: 0.7277 - val_loss: 0.6547 - val_acc: 0.6240\n",
      "Epoch 50/50\n",
      "92346/92346 [==============================] - 8s 88us/step - loss: 0.5307 - acc: 0.7313 - val_loss: 0.6688 - val_acc: 0.6030\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc5c4a50250>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(f, l,\n",
    "          batch_size=256,\n",
    "          epochs=50,\n",
    "          verbose=1,\n",
    "          validation_data=(f_test, l_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
